import argparse, os, sys, glob
import cv2
import torch
import numpy as np
from collections import defaultdict
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from imwatermark import WatermarkEncoder
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
import random
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
import torchvision


import os
import sys
current_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(current_dir, "../../Paint-by-Example"))
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler

#from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
import clip
from torchvision.transforms import Resize

from .vis_prior_generator import InpaintingVPG
from .utils import copy_pil_to_numpy


#safety_model_id = "CompVis/stable-diffusion-safety-checker"
#safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
#safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)


class InpaintingByExampleVPG(InpaintingVPG):
    """Visual-prior generator that inpaints objects into sampled layouts with
    the Paint-by-Example diffusion model.

    Reference ("example") images are generated zero-shot by a Stable Diffusion
    text-to-image pipeline prompted with the category name; layout sampling and
    the prior bank come from the ``InpaintingVPG`` parent class.
    """

    def __init__(self, 
                 annotation, 
                 im_folder, 
                 config_path, 
                 ckpt_path, 
                 device="cuda", 
                 plms=True,  # TODO: is it better than ddlm?
                 zero_shot_prior=True,  # TODO: support retrieving ref image from annotation
                 ):
        """Load the Paint-by-Example model and the zero-shot prior pipeline.

        Args:
            annotation: dataset annotation used to fill the prior bank
                (exact format defined by ``InpaintingVPG`` — see parent class).
            im_folder: folder holding the images referenced by *annotation*.
            config_path: OmegaConf YAML describing the Paint-by-Example model.
            ckpt_path: checkpoint file containing the model ``state_dict``.
            device: torch device string for the inpainting model.
            plms: use the PLMS sampler when True, otherwise DDIM.
            zero_shot_prior: when True, reference images are generated by a
                Stable Diffusion pipeline; the False branch (cropped bbox
                content as reference) is not implemented yet.
        """
        
        self.prior_bank = defaultdict(list)
        self.annotation = annotation
        self.im_folder = im_folder
        self.zero_shot_prior = zero_shot_prior  # For ref image: if zero_shot_prior is False, use a bbox content; if it's True, use an image generated by SD

        # Populate self.prior_bank from the annotation (implemented in the parent class).
        self.update_prior_bank_with_images(annotation=annotation, im_folder=im_folder)

        # initialize diffusion model
        self.config = OmegaConf.load(f"{config_path}")
        self.model = load_model_from_config(self.config, f"{ckpt_path}")

        self.device = torch.device(device)
        self.model = self.model.to(self.device)

        if plms:
            self.sampler = PLMSSampler(self.model)
        else:
            self.sampler = DDIMSampler(self.model)

        if zero_shot_prior:
            from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
            #model_id = "stabilityai/stable-diffusion-1-5"
            model_id = "runwayml/stable-diffusion-v1-5"
            # fp16 SD pipeline with a DPM-Solver++ scheduler for fast reference generation.
            self.prior_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
            self.prior_pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.prior_pipe.scheduler.config)
            # NOTE(review): pipeline is pinned to "cuda" rather than self.device — confirm intended.
            self.prior_pipe = self.prior_pipe.to("cuda")
        else:
            raise NotImplementedError  # TODO: add bbox content to prior dict

    def get_ref_image(self, cat):
        """Generate a zero-shot reference image for category *cat*.

        Prompts the SD prior pipeline with the bare category name and returns
        a numpy copy of the generated PIL image.
        """
        # prompt = f"a photo of a {cat}" 
        prompt = f"{cat}"
        image = self.prior_pipe(prompt).images[0]
        image = copy_pil_to_numpy(image)
        return image

    def sample_images_to_inpaint(self, num_images, pixel_size):
        """Sample *num_images* layouts and build one inpainting task per image.

        For each sampled layout, one bbox/category pair is picked at random,
        a reference image is generated for that category, and a mask is built
        over the chosen bbox. Layouts with no bboxes produce empty bbox/cat
        lists and zero-filled placeholder ref/mask arrays.

        Returns:
            Tuple of parallel lists, each of length *num_images*:
            (images, mask_images, bboxes, categories, ref_images, samples),
            where *samples* keeps the full ground-truth layout info.
        """
        # reuse sample_image_visual_priors
        samples = self.sample_image_visual_priors(num_layouts=num_images, pixel_size=pixel_size)
        
        images = []
        bboxes = []
        categories = []
        ref_images = []
        mask_images = []
        for i in tqdm(range(num_images)):
            sample = samples[i]
            if len(sample["bboxes"]) == 0:
                # Nothing to inpaint: keep the layout image and emit placeholders.
                images.append(sample["vis_priors"][0])
                bboxes.append([])
                categories.append([])
                ref_images.append(np.zeros_like(sample["vis_priors"][0]))
                mask_images.append(np.zeros_like(sample["vis_priors"][0]))
                continue
            
            image = sample["vis_priors"][0]
            bbox_idx = random.choice(range(len(sample["bboxes"])))  # TODO: try more than 1 
            bbox = sample["bboxes"][bbox_idx]
            cat = sample["cats"][bbox_idx]
            ref_image = self.get_ref_image(cat=cat)

            # Mask covering the selected bbox (helper from the parent class).
            mask_image = self.sample_an_image_mask(image, [bbox])

            images.append(image)
            bboxes.append([bbox])
            categories.append([cat])
            ref_images.append(ref_image)
            mask_images.append(mask_image)

        # bboxes and cats here are for inpainting, the gt info of the image is in samples
        return images, mask_images, bboxes, categories, ref_images, samples

    def inpaint_one_image(self,
                        img_p,  # PIL RGB Image
                        ref_p,  # PIL RGB Image
                        mask,  # PIL L Image
                        precision="autocast",
                        scale=5,  # classifier-free guidance scale (1.0 disables guidance)
                        fixed_code=False,  # start every sample from one fixed Gaussian latent
                        n_samples=1,
                        C=4, # latent channels
                        H=512,
                        W=512,
                        f=8,  # downsampling factor
                        ddim_steps=50,
                        ddim_eta=0.,
                        ):
        """Inpaint the masked region of *img_p*, guided by the example *ref_p*.

        Args:
            img_p: source image (PIL RGB); assumed to match H x W — TODO confirm.
            ref_p: example/reference image (PIL RGB); resized to 224x224 for CLIP.
            mask: PIL "L" mask; white (255) marks the region to inpaint
                (the mask is inverted and binarized below).
            precision: "autocast" wraps sampling in torch autocast on CUDA.
            scale: unconditional guidance scale passed to the sampler.
            fixed_code: when True, reuse one fixed start noise tensor.
            n_samples: batch size passed to the sampler.
            C, H, W, f: latent channels, output height/width, VAE downsample factor.
            ddim_steps, ddim_eta: sampler step count and eta.

        Returns:
            Tuple of lists of PIL images, one entry per generated sample:
            (grids, result_images, masks, gts, inpaints, refs).
        """
        start_code = None
        if fixed_code:
            start_code = torch.randn([n_samples, C, H // f, W // f], device=self.device)
        precision_scope = autocast if precision=="autocast" else nullcontext
        with torch.no_grad():
            with precision_scope("cuda"):
                with self.model.ema_scope():
                    # Source image -> [-1, 1] tensor with a batch dimension.
                    image_tensor = get_tensor()(img_p)
                    image_tensor = image_tensor.unsqueeze(0)

                    # Reference image -> CLIP input size + CLIP normalization.
                    ref_p = ref_p.resize((224,224))
                    ref_tensor=get_tensor_clip()(ref_p)
                    ref_tensor = ref_tensor.unsqueeze(0)

                    # Invert and binarize the mask: after this, 0 marks the
                    # region to inpaint and 1 the region to keep.
                    mask = np.array(mask)[None,None]
                    mask = 1 - mask.astype(np.float32)/255.0
                    mask[mask < 0.5] = 0
                    mask[mask >= 0.5] = 1
                    mask_tensor = torch.from_numpy(mask)

                    # Masked source: the inpaint region is zeroed out.
                    inpaint_image = image_tensor*mask_tensor
                    
                    test_model_kwargs={}
                    test_model_kwargs['inpaint_mask']=mask_tensor.to(self.device)
                    test_model_kwargs['inpaint_image']=inpaint_image.to(self.device)
                    ref_tensor=ref_tensor.to(self.device)

                    uc = None
                    if scale != 1.0:
                        # Learned unconditional embedding for classifier-free guidance.
                        uc = self.model.learnable_vector
                    c = self.model.get_learned_conditioning(ref_tensor.to(torch.float16))
                    c = self.model.proj_out(c)
                    inpaint_mask=test_model_kwargs['inpaint_mask']
                    # Encode the masked image to latent space and resize the
                    # mask down to the latent resolution.
                    z_inpaint = self.model.encode_first_stage(test_model_kwargs['inpaint_image'])
                    z_inpaint = self.model.get_first_stage_encoding(z_inpaint).detach()
                    test_model_kwargs['inpaint_image']=z_inpaint
                    test_model_kwargs['inpaint_mask']=Resize([z_inpaint.shape[-2],z_inpaint.shape[-1]])(test_model_kwargs['inpaint_mask'])

                    shape = [C, H // f, W // f]
                    samples_ddim, _ = self.sampler.sample(S=ddim_steps,
                                                        conditioning=c,
                                                        batch_size=n_samples,
                                                        shape=shape,
                                                        verbose=False,
                                                        unconditional_guidance_scale=scale,
                                                        unconditional_conditioning=uc,
                                                        eta=ddim_eta,
                                                        x_T=start_code,
                                                        test_model_kwargs=test_model_kwargs)

                    # Decode latents back to image space and map [-1, 1] -> [0, 1].
                    x_samples_ddim = self.model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                    x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()

                    # NOTE(review): the safety-checker output is immediately
                    # discarded by the reassignment on the next line (and
                    # check_safety is a pass-through in this file anyway).
                    x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
                    x_checked_image=x_samples_ddim
                    x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)

                    def un_norm(x):
                        # Map [-1, 1] back to [0, 1].
                        return (x+1.0)/2.0
                    def un_norm_clip(x):
                        # Undo CLIP per-channel normalization (mutates x in place).
                        x[0,:,:] = x[0,:,:] * 0.26862954 + 0.48145466
                        x[1,:,:] = x[1,:,:] * 0.26130258 + 0.4578275
                        x[2,:,:] = x[2,:,:] * 0.27577711 + 0.40821073
                        return x

                    grids = []
                    result_images = []
                    masks = []
                    gts = []
                    inpaints = []
                    refs = []

                    for i,x_sample in enumerate(x_checked_image_torch):
                        wm_encoder = None  # Do not use watermark

                        # 4-panel grid: original | masked input | reference | result.
                        all_img=[]
                        all_img.append(un_norm(image_tensor[i]).cpu())
                        all_img.append(un_norm(inpaint_image[i]).cpu())
                        ref_img=ref_tensor
                        ref_img=Resize([H, W])(ref_img)
                        all_img.append(un_norm_clip(ref_img[i]).cpu())
                        all_img.append(x_sample)
                        grid = torch.stack(all_img, 0)
                        grid = make_grid(grid)
                        grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                        grid = Image.fromarray(grid.astype(np.uint8))
                        grid = put_watermark(grid, wm_encoder)
                        #grid.save(os.path.join(grid_path, 'grid-'+filename[:-4]+'_'+str(opt.seed)+'.png'))
                        grids.append(grid)
                        


                        # Final inpainted result as a standalone image.
                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                        img = Image.fromarray(x_sample.astype(np.uint8))
                        img = put_watermark(img, wm_encoder)
                        #img.save(os.path.join(result_path, filename[:-4]+'_'+str(opt.seed)+".png"))
                        result_images.append(img)
                        
                        # Binarized keep-mask rendered as an RGB image.
                        mask_save=255.*rearrange(un_norm(inpaint_mask[i]).cpu(), 'c h w -> h w c').numpy()
                        mask_save= cv2.cvtColor(mask_save,cv2.COLOR_GRAY2RGB)
                        mask_save = Image.fromarray(mask_save.astype(np.uint8))
                        #mask_save.save(os.path.join(sample_path, filename[:-4]+'_'+str(opt.seed)+"_mask.png"))
                        masks.append(mask_save)

                        # Ground-truth (unmasked) source image.
                        GT_img=255.*rearrange(all_img[0], 'c h w -> h w c').numpy()
                        GT_img = Image.fromarray(GT_img.astype(np.uint8))
                        #GT_img.save(os.path.join(sample_path, filename[:-4]+'_'+str(opt.seed)+"_GT.png"))
                        gts.append(GT_img)

                        # Masked (inpaint-input) image.
                        inpaint_img=255.*rearrange(all_img[1], 'c h w -> h w c').numpy()
                        inpaint_img = Image.fromarray(inpaint_img.astype(np.uint8))
                        #inpaint_img.save(os.path.join(sample_path, filename[:-4]+'_'+str(opt.seed)+"_inpaint.png"))
                        inpaints.append(inpaint_img)

                        # Reference/example image (de-normalized, resized to H x W).
                        ref_img=255.*rearrange(all_img[2], 'c h w -> h w c').numpy()
                        ref_img = Image.fromarray(ref_img.astype(np.uint8))
                        #ref_img.save(os.path.join(sample_path, filename[:-4]+'_'+str(opt.seed)+"_ref.png"))
                        refs.append(ref_img)
        return grids, result_images, masks, gts, inpaints, refs
        

def chunk(it, size):
    """Return an iterator over consecutive tuples of *size* items from *it*;
    the final tuple may be shorter when the input does not divide evenly."""
    source = iter(it)

    def next_piece():
        # Empty tuple acts as the sentinel that terminates iteration.
        return tuple(islice(source, size))

    return iter(next_piece, ())


def get_tensor_clip(normalize=True, toTensor=True):
    """Build a PIL->tensor transform using CLIP's normalization statistics.

    NOTE(review): this definition is duplicated verbatim at the bottom of the
    file; the later definition shadows this one at import time.

    Args:
        normalize: apply CLIP per-channel mean/std normalization.
        toTensor: prepend ToTensor (PIL/uint8 -> float tensor in [0, 1]).
    """
    transform_list = []
    if toTensor:
        transform_list += [torchvision.transforms.ToTensor()]

    if normalize:
        transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                                (0.26862954, 0.26130258, 0.27577711))]
    return torchvision.transforms.Compose(transform_list)


def numpy_to_pil(images):
    """Convert a numpy image (HWC) or batch of images (NHWC) with values in
    [0, 1] into a list of PIL images."""
    if images.ndim == 3:
        # Promote a single image to a one-element batch.
        images = images[None, ...]
    scaled = (images * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in scaled]


def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by *config* and load weights from the
    checkpoint at *ckpt*; returns the model on CUDA in eval mode.

    Args:
        config: OmegaConf object whose ``model`` node is instantiable.
        ckpt: path to a checkpoint containing a ``state_dict``.
        verbose: print missing/unexpected state-dict keys when True.
    """
    print(f"Loading model from {ckpt}")
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        print(f"Global Step: {checkpoint['global_step']}")
    state = checkpoint["state_dict"]

    model = instantiate_from_config(config.model)
    # strict=False tolerates partial checkpoints; report mismatches if asked.
    missing, unexpected = model.load_state_dict(state, strict=False)
    if len(missing) > 0 and verbose:
        print("missing keys:")
        print(missing)
    if len(unexpected) > 0 and verbose:
        print("unexpected keys:")
        print(unexpected)

    model.cuda()
    model.eval()
    return model


def put_watermark(img, wm_encoder=None):
    """Embed an invisible 'dwtDct' watermark into *img* (PIL RGB).

    When *wm_encoder* is None the image is returned untouched.
    """
    if wm_encoder is None:
        return img
    # The encoder works on BGR uint8 arrays, so round-trip through OpenCV.
    bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    bgr = wm_encoder.encode(bgr, 'dwtDct')
    return Image.fromarray(bgr[:, :, ::-1])


def load_replacement(x):
    """Return 'assets/rick.jpeg' resized/converted to match array *x*
    (same shape and dtype, values in [0, 1]); on any failure — missing asset,
    shape mismatch, non-array input — return *x* unchanged."""
    try:
        height, width = x.shape[0], x.shape[1]
        replacement = Image.open("assets/rick.jpeg").convert("RGB").resize((width, height))
        replacement = (np.array(replacement) / 255.0).astype(x.dtype)
        assert replacement.shape == x.shape
        return replacement
    except Exception:
        # Best-effort: fall back to the original image on any error.
        return x


def check_safety(x_image):
    """Placeholder NSFW safety check: passes images through unchanged.

    The real StableDiffusionSafetyChecker call was removed because of
    diffusers version conflicts
    (https://github.com/CompVis/stable-diffusion/issues/627), so this always
    reports no NSFW content.

    Returns:
        (x_image, False) — the untouched input and a no-concept flag.
    """
    return x_image, False


def get_tensor(normalize=True, toTensor=True):
    """Build a PIL->tensor transform; with *normalize*, channels are mapped
    from [0, 1] to [-1, 1] via mean/std 0.5.

    Args:
        normalize: append the 0.5/0.5 Normalize step.
        toTensor: prepend ToTensor (PIL/uint8 -> float tensor in [0, 1]).
    """
    steps = []
    if toTensor:
        steps.append(torchvision.transforms.ToTensor())
    if normalize:
        steps.append(
            torchvision.transforms.Normalize((0.5, 0.5, 0.5),
                                             (0.5, 0.5, 0.5)))
    return torchvision.transforms.Compose(steps)


def get_tensor_clip(normalize=True, toTensor=True):
    """Build a PIL->tensor transform using CLIP's per-channel mean/std.

    Args:
        normalize: append CLIP mean/std normalization.
        toTensor: prepend ToTensor (PIL/uint8 -> float tensor in [0, 1]).
    """
    steps = []
    if toTensor:
        steps.append(torchvision.transforms.ToTensor())
    if normalize:
        steps.append(
            torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                             (0.26862954, 0.26130258, 0.27577711)))
    return torchvision.transforms.Compose(steps)
